Added support for Column Major C [Bias] in GEMM #656
Force-pushed from b0e4187 to 31cdc64.
Signed-off-by: Maiti, Kausik <[email protected]>
Force-pushed from 31cdc64 to d19437b.
std::cout << "\n\nRunning BMG GEMM with bfloat16, RowMajor Bias and bfloat16, RowMajor Output" << std::endl << std::flush;
test_bmg_gemm<bfloat16_t, cutlass::layout::RowMajor, bfloat16_t>(options, hw_info);
std::cout << "\n\nRunning BMG GEMM with bfloat16, ColumnMajor Bias and bfloat16, RowMajor Output" << std::endl << std::flush;
test_bmg_gemm<bfloat16_t, cutlass::layout::ColumnMajor, bfloat16_t>(options, hw_info);
I'm not sure the 00_bmg_gemm example is the right place to test all these options. As the very first example, I think it should probably be the simplest one. This feels like something that belongs in a test, or maybe we could have a separate 00_bmg_gemm_bias or 00_bmg_gemm_with_beta executable?
using ActualGmemTiledCopyC = replace_void_t<CopyOpG2R, DefaultCopyOpG2R>;
constexpr bool IsColMajorC = cutlass::gemm::detail::is_major<0, StrideC>();
using ActualGmemTiledCopyC = replace_void_t<CopyOpG2R, std::conditional_t<IsColMajorC, CopyOpG2RTransposed, DefaultCopyOpG2R>>;
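For readers unfamiliar with the layout check, here is a minimal standalone sketch of the idea. It assumes that `is_major<0, StrideC>()` is true when mode 0 (the M dimension) of the stride is a compile-time unit stride, which is what column-major C means. `UnitStride`, `is_major_sketch`, and `linear_offset` are hypothetical stand-ins written with the standard library only. They are not CUTLASS or CuTe APIs.

```cpp
#include <tuple>
#include <type_traits>

// Hypothetical stand-in for a compile-time stride of 1 (not a CuTe type).
struct UnitStride {};

// Hypothetical stand-in for the layout check: true when the given mode of
// the stride tuple is the compile-time unit stride, i.e. that mode is the
// contiguous one.
template <int Mode, class Stride>
constexpr bool is_major_sketch() {
  return std::is_same_v<std::tuple_element_t<Mode, Stride>, UnitStride>;
}

// Stride tuples ordered as (stride along M, stride along N).
using ColMajorStrideC = std::tuple<UnitStride, long>;  // (1, ldc)
using RowMajorStrideC = std::tuple<long, UnitStride>;  // (ldc, 1)

// Linear offset of element (m, n) for a leading dimension ld.
constexpr long linear_offset(long m, long n, long ld, bool col_major) {
  return col_major ? m + n * ld : m * ld + n;
}

static_assert(is_major_sketch<0, ColMajorStrideC>(), "M is contiguous");
static_assert(!is_major_sketch<0, RowMajorStrideC>(), "N is contiguous");
```

With a leading dimension of 8, stepping one element along M moves by 1 in the column-major case and by 8 in the row-major case, which is why the column-major path needs a different (transposing) load.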
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggest refactoring this a bit to make the logic/naming clearer:
using DefaultCopyOpR2GNontranspose = XE_STORE_2D<CopyBitsD, cute::gcd(8, get<0>(EpilogueTile{})), cute::gcd(512 / CopyBitsD, get<1>(EpilogueTile{}))>;
using DefaultCopyOpR2GTranspose = XE_LOAD_2D_TRANSPOSE<CopyBitsCTranspose, cute::gcd(512 / CopyBitsC, get<1>(EpilogueTile{})), cute::gcd(8 / Sub32BitFactor, get<0>(EpilogueTile{}))>;
using DefaultCopyOpR2G = conditional_t<IsColMajorC, DefaultCopyOpR2GTranspose, DefaultCopyOpR2GNontranspose>;
using ActualGmemTiledCopyC = replace_void_t<CopyOpG2R, DefaultCopyOpG2R>;
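To make the shape of this refactoring concrete, here is a small standalone sketch of the same pattern: pick the default copy op from the layout first, then let a non-void user-supplied op override it. `PlainLoadOp`, `TransposeLoadOp`, `UserLoadOp`, `replace_void_sketch_t`, and `CopySelector` are hypothetical placeholders, not the real XE copy atoms or the CUTLASS `replace_void_t` helper. The sketch only assumes that `replace_void_t<T, Default>` yields `Default` when `T` is `void`.

```cpp
#include <type_traits>

// Hypothetical placeholders for the real copy atoms.
struct PlainLoadOp {};      // default load for row-major C
struct TransposeLoadOp {};  // default load for column-major C
struct UserLoadOp {};       // an op explicitly supplied by the user

// Assumed behaviour of replace_void_t: fall back to Default when T is void.
template <class T, class Default>
using replace_void_sketch_t =
    std::conditional_t<std::is_void_v<T>, Default, T>;

template <class CopyOpG2R, bool IsColMajorC>
struct CopySelector {
  // Step 1: the layout alone decides the default.
  using DefaultCopyOpG2R =
      std::conditional_t<IsColMajorC, TransposeLoadOp, PlainLoadOp>;
  // Step 2: an explicit (non-void) op always overrides the default.
  using type = replace_void_sketch_t<CopyOpG2R, DefaultCopyOpG2R>;
};

static_assert(
    std::is_same_v<CopySelector<void, true>::type, TransposeLoadOp>, "");
static_assert(
    std::is_same_v<CopySelector<void, false>::type, PlainLoadOp>, "");
static_assert(
    std::is_same_v<CopySelector<UserLoadOp, true>::type, UserLoadOp>, "");
```

Keeping the layout decision inside the default alias means the final `replace_void_t` line stays identical to the pre-change code, which is the readability gain the suggestion is after.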
Added support for Column Major C [Bias] in GEMM.
Following changes have been made.
To do: